import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.preprocessing import sequence
from keras.layers import Input, Embedding, Dense, LSTM, Dot, merge
from keras.models import Model
from keras.layers.core import Dense, Dropout, RepeatVector, Activation, Flatten, Permute, Lambda, Reshape
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import RMSprop, Adam
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from keras.layers.wrappers import TimeDistributed
from keras.constraints import maxnorm

# Limit this process to a 0.1 fraction of GPU memory and register the
# resulting TensorFlow session as the one Keras uses.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1
set_session(tf.Session(config=config))

# Passed to the Embedding layer as input_dim.
VOCAB_SIZE = 2000

# Passed to the Embedding layer as output_dim.
EMBED_SIZE = 64
# LSTM output sizes: HIDDEN_SIZE_g for the gender discriminator,
# HIDDEN_SIZE_l for the location discriminator, HIDDEN_SIZE for the shared
# and predictor LSTMs.
HIDDEN_SIZE_g = 128
HIDDEN_SIZE_l = 128
HIDDEN_SIZE = 128
# Mini-batch size used by every fit() call.
BATCH_SIZE = 32
# Number of outer training rounds; each round runs one epoch per model.
EPOCHS = 150
# Sequence length: inputs are padded/truncated to this and the Input layer
# is declared with this shape.
MAX_LEN = 150


def gan_att_model(x_train, y_train, g_list, l_list, ng, gl):
    """Build and train the attention predictor with two auxiliary discriminators.

    A shared embedding + LSTM feeds three branches: a gender discriminator
    with ``ng`` sigmoid heads, a location discriminator with ``gl`` sigmoid
    heads, and a predictor whose attention scores are applied to both
    discriminator sequences and summed. Each training round fits the two
    discriminator models for one epoch and then the combined model for one
    epoch.

    Args:
        x_train: Padded input sequences fed to every ``fit`` call.
        y_train: List of target arrays for the predictor heads; one sigmoid
            head is created per entry.
        g_list: List of target arrays for the gender discriminator heads.
        l_list: List of target arrays for the location discriminator heads.
        ng: Number of gender discriminator heads.
        gl: Number of location discriminator heads.

    Returns:
        The compiled and trained combined model, whose outputs are the
        predictor heads followed by the gender and location heads.
    """
    input_layer = Input(shape=(MAX_LEN,), dtype='int32')

    emb = Embedding(output_dim=EMBED_SIZE, input_dim=VOCAB_SIZE,
                    input_length=MAX_LEN, dropout=0.2)(input_layer)
    lstm_layer = LSTM(output_dim=HIDDEN_SIZE, return_sequences=True)(emb)

    # discriminator_gender
    opt = Adam(lr=0.0001)
    lstm_layer_dg = LSTM(output_dim=HIDDEN_SIZE_g, return_sequences=True)(lstm_layer)
    att_dg = attention(lstm_layer_dg, lstm_layer_dg)
    out_dg_list = [Dense(1, activation='sigmoid')(att_dg) for _ in range(ng)]

    model_dg = Model(input=input_layer, output=out_dg_list)
    model_dg.compile(loss='binary_crossentropy', optimizer=opt)

    # NOTE(review): the discriminator layers are wired directly into
    # `combined` below rather than nested as a sub-model; confirm that this
    # flag has the intended freezing effect on the installed Keras version.
    model_dg.trainable = False

    # discriminator_location
    lstm_layer_dl = LSTM(output_dim=HIDDEN_SIZE_l, return_sequences=True)(lstm_layer)
    att_dl = attention(lstm_layer_dl, lstm_layer_dl)
    out_dl_list = [Dense(1, activation='sigmoid')(att_dl) for _ in range(gl)]

    model_dl = Model(input=input_layer, output=out_dl_list)
    model_dl.compile(loss='binary_crossentropy', optimizer=opt)

    model_dl.trainable = False

    # generator (predictor): attention scores come from the predictor LSTM
    # and are applied to each discriminator's sequence output.
    lstm_layer_g = LSTM(output_dim=HIDDEN_SIZE, return_sequences=True)(lstm_layer)
    g_over_dl = attention(lstm_layer_g, lstm_layer_dl)
    g_over_dg = attention(lstm_layer_g, lstm_layer_dg)
    merged_g = merge([g_over_dl, g_over_dg], "sum")
    # One head per predictor target: the combined fit below is given
    # y_train + g_list + l_list, so the head count must equal len(y_train).
    out_g = [Dense(1, activation='sigmoid')(merged_g) for _ in range(len(y_train))]

    # model_g is compiled here, but the loop below only fits model_dg,
    # model_dl and combined.
    model_g = Model(input=input_layer, output=out_g)
    model_g.compile(loss='binary_crossentropy', optimizer=opt)

    # combine
    combined = Model(input=input_layer, output=out_g + out_dg_list + out_dl_list)
    combined.compile(loss='binary_crossentropy', optimizer=opt)

    # train
    for epoch in range(EPOCHS):
        # Train the discriminators
        loss_dg = model_dg.fit(x_train, g_list, batch_size=BATCH_SIZE, epochs=1).history['loss']
        loss_dl = model_dl.fit(x_train, l_list, batch_size=BATCH_SIZE, epochs=1).history['loss']

        # Train the generator through the combined model
        loss_g = combined.fit(x_train, y_train + g_list + l_list, batch_size=BATCH_SIZE, epochs=1).history['loss']

        # Log the progress (D1 = gender discriminator, D2 = location discriminator)
        print('%d [D1 loss: %f] [D2 loss: %f] [G loss: %f]' % (epoch, loss_dg[0], loss_dl[0], loss_g[0]))

    return combined


def attention(input_matrix1, input_matrix2):
    """Pool ``input_matrix2`` using weights scored from ``input_matrix1``.

    A single-unit ``Dense`` is applied per timestep to ``input_matrix1``; the
    scores are reshaped so that ``MAX_LEN`` is the last axis, passed through a
    softmax, combined with ``input_matrix2`` via ``Dot(axes=[2, 1])`` and
    flattened.
    """
    # assumes both inputs are (batch, MAX_LEN, features) sequences -- TODO confirm
    step_scores = TimeDistributed(Dense(1))(input_matrix1)

    # Identity pass-through layer, kept so the layer graph stays unchanged.
    step_scores = Lambda(lambda t: t, output_shape=lambda shape: shape)(step_scores)
    score_row = Reshape((-1, MAX_LEN))(step_scores)

    weights = Activation("softmax")(score_row)

    weighted = Dot(axes=[2, 1])([weights, input_matrix2])
    return Flatten()(weighted)


def evaluate(pre, real, n_labels=5):
    """Print per-label precision, recall and F1, then the macro-averaged F1.

    Args:
        pre: Predictions indexed as ``pre[k][i][0]`` (label ``k``, sample
            ``i``); a score ``>= 0.5`` counts as a positive prediction.
        real: Ground truth indexed as ``real[i][k]``; a value equal to ``1``
            marks a positive.
        n_labels: Number of leading labels to score. Must be positive;
            defaults to 5, the previously hard-coded value.

    Returns:
        The mean of the per-label F1 scores (the value printed last).
    """
    f1_total = 0
    for k in range(n_labels):
        true_pos = 0
        predicted_pos = 0
        actual_pos = 0
        for i, row in enumerate(real):
            is_actual = row[k] == 1
            is_predicted = pre[k][i][0] >= 0.5
            if is_actual and is_predicted:
                true_pos += 1
            if is_actual:
                actual_pos += 1
            if is_predicted:
                predicted_pos += 1

        # A label with no predicted (or no actual) positives scores 0 rather
        # than dividing by zero.
        precision = true_pos * 1.0 / predicted_pos if predicted_pos > 0 else 0
        recall = true_pos * 1.0 / actual_pos if actual_pos > 0 else 0
        if precision + recall > 0:
            f1 = 2 * recall * precision * 1.0 / (precision + recall)
        else:
            f1 = 0

        # Single pre-formatted argument: prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print('%s %s %s' % (precision, recall, f1))
        f1_total += f1

    macro_f1 = f1_total * 1.0 / n_labels
    print('%s' % macro_f1)
    return macro_f1


def learning(x_train, y_train, x_test, y_test, g_list, l_list, ng, nl):
    """Pad the inputs, train the model and print its test-set evaluation.

    Both input splits are padded to ``MAX_LEN``; every entry of ``y_train``,
    ``g_list`` and ``l_list`` is converted to a NumPy array before being
    handed to ``gan_att_model``. The trained model's predictions on the
    padded test inputs are then passed to ``evaluate`` together with
    ``y_test``.
    """
    # Fixed-length inputs for both splits.
    padded_train = sequence.pad_sequences(x_train, maxlen=MAX_LEN)
    padded_test = sequence.pad_sequences(x_test, maxlen=MAX_LEN)

    # One array per target column.
    predictor_targets = list(map(np.array, y_train))
    gender_targets = list(map(np.array, g_list))
    location_targets = list(map(np.array, l_list))

    trained = gan_att_model(padded_train, predictor_targets,
                            gender_targets, location_targets, ng, nl)

    evaluate(trained.predict(padded_test), y_test)